# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""train DCGAN on ModelArts, get checkpoint files and air/onnx models."""
import argparse
import os
import datetime
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

import mindspore.common.dtype as mstype
from mindspore import context
from mindspore import nn, Tensor, export
from mindspore.train.callback import CheckpointConfig, _InternalCallbackParam, ModelCheckpoint, RunContext
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank
import moxing as mox

from src.dataset import create_dataset_imagenet
from src.config import dcgan_imagenet_cfg as cfg
from src.generator import Generator
from src.discriminator import Discriminator
from src.cell import WithLossCellD, WithLossCellG
from src.dcgan import DCGAN


# Offset and scale that save_imgs applies to generator output before plotting:
# pixel = value * NORMALIZE_STD + NORMALIZE_MEAN.
# NOTE(review): presumably the inverse of the dataset normalization (pixels
# mapped to roughly [-1, 1]) -- confirm against src.dataset.
NORMALIZE_MEAN = 127.5
NORMALIZE_STD = 127.5

def save_imgs(gen_imgs, idx):
    """
    Save images in 4 * 4 format when training on the modelarts

    The grid is written to ``/cache/images/<idx>.png``. The input array is
    left untouched, and the figure used for drawing is always closed.

    Inputs:
        - **gen_imgs** (array) - Images generated by the generator, indexed as
          (image, channel, row, column). Only the first 16 images are drawn,
          since the grid has 4 * 4 cells.
        - **idx** (int) - Training epoch, used as the output file name.
    """
    matplotlib.use('Agg')
    # Draw on a dedicated figure: relying on pyplot's "current figure" would
    # paint the grid on top of whatever figure was opened last (for example a
    # loss plot) and would stack new axes onto the same figure every epoch.
    fig = plt.figure()
    try:
        for index, img in enumerate(gen_imgs[:16]):
            axis = fig.add_subplot(4, 4, index + 1)
            # De-normalize into a new array (the caller's data is not modified)
            # and clip to the valid 8-bit range so imshow receives clean pixels.
            pixels = np.clip(img * NORMALIZE_STD + NORMALIZE_MEAN, 0, 255)
            # imshow wants (row, column, channel); the input is channel-first.
            axis.imshow(np.transpose(pixels, (1, 2, 0)).astype(np.uint8))
            axis.axis("off")
        fig.savefig("/cache/images/{}.png".format(idx))
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)


def save_losses(G_losses_list, D_losses_list, idx):
    """
    Save Loss visualization images when training on the modelarts

    The plot is written to ``/cache/losses/<idx>.png`` and the figure is
    closed afterwards, so repeated calls do not accumulate open figures.

    Inputs:
        - **G_losses_list** (list) - Generator loss list.
        - **D_losses_list** (list) - Discriminator loss list.
        - **idx** (int) - Training epoch, used as the output file name.
    """
    fig = plt.figure(figsize=(10, 5))
    try:
        plt.title("Generator and Discriminator Loss During Training")
        plt.plot(G_losses_list, label="G")
        plt.plot(D_losses_list, label="D")
        plt.xlabel("iterations")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig("/cache/losses/{}.png".format(idx))
    finally:
        # Without this the figure stays open (and stays pyplot's current
        # figure), leaking one figure per call.
        plt.close(fig)


# Command-line interface of the training job.
parser = argparse.ArgumentParser(description='MindSpore dcgan training')

# Remote locations: dataset input plus the three output destinations.
parser.add_argument(
    '--data_url',
    default=None,
    help='Directory contains ImageNet-1k dataset.',
)
parser.add_argument(
    '--train_url',
    default=None,
    help='Directory of training output.',
)
parser.add_argument(
    '--images_url',
    default=None,
    help='Location of images outputs.',
)
parser.add_argument(
    '--losses_url',
    default=None,
    help='Location of losses outputs.',
)

# Settings for the model exported once training has finished.
parser.add_argument(
    "--file_format",
    type=str,
    default="AIR",
    help="Format of export file.",
)
parser.add_argument(
    "--file_name",
    type=str,
    default="dcgan",
    help="Output file name.",
)

# Number of training epochs; falls back to the value from the project config.
parser.add_argument(
    '--epoch_size',
    type=int,
    default=cfg.epoch_size,
    help='Epoch size of training.',
)

args = parser.parse_args()

# Device topology comes from the launcher's environment. When the variables
# are not set, fall back to a single device with id 0 instead of failing with
# an opaque TypeError from int(None).
device_id = int(os.getenv('DEVICE_ID', '0'))
device_num = int(os.getenv('RANK_SIZE', '1'))

# Local working directories. Dataset and checkpoint directories are suffixed
# with the device id; the image and loss directories are shared.
local_input_url = '/cache/data' + str(device_id)
local_output_url = '/cache/ckpt' + str(device_id)
local_images_url = '/cache/images'
local_losses_url = '/cache/losses'

context.set_context(mode=context.GRAPH_MODE,
                    device_target="Ascend", save_graphs=False)
context.set_context(device_id=device_id)

# Multi-device run: initialise communication and enable data parallelism with
# averaged gradients.
if device_num > 1:
    init()
    context.set_auto_parallel_context(device_num=device_num,
                                      global_rank=device_id,
                                      parallel_mode=ParallelMode.DATA_PARALLEL,
                                      gradients_mean=True)
    rank = get_rank()
else:
    rank = 0

# Stage the dataset and any existing image/loss outputs from the remote URLs
# into the local working directories.
mox.file.copy_parallel(src_url=args.data_url, dst_url=local_input_url)
mox.file.copy_parallel(src_url=args.images_url, dst_url=local_images_url)
mox.file.copy_parallel(src_url=args.losses_url, dst_url=local_losses_url)

if __name__ == '__main__':
    # Train the DCGAN on the locally staged dataset. After every epoch,
    # device 0 saves a checkpoint, a grid of sample images and the loss curves
    # and uploads them. When training is done the generator is exported and
    # the exported file is uploaded to the training output location.

    # Load Dataset
    ds = create_dataset_imagenet(local_input_url, num_parallel_workers=2)

    steps_per_epoch = ds.get_dataset_size()

    # Define Network
    netD = Discriminator()
    netG = Generator()

    criterion = nn.BCELoss(reduction='mean')

    # Both loss cells wrap the same generator and discriminator instances.
    netD_with_criterion = WithLossCellD(netD, netG, criterion)
    netG_with_criterion = WithLossCellG(netD, netG, criterion)

    # One Adam optimizer per network, each restricted to that network's
    # trainable parameters.
    optimizerD = nn.Adam(netD.trainable_params(),
                         learning_rate=cfg.learning_rate, beta1=cfg.beta1)
    optimizerG = nn.Adam(netG.trainable_params(),
                         learning_rate=cfg.learning_rate, beta1=cfg.beta1)

    myTrainOneStepCellForD = nn.TrainOneStepCell(
        netD_with_criterion, optimizerD)
    myTrainOneStepCellForG = nn.TrainOneStepCell(
        netG_with_criterion, optimizerG)

    dcgan = DCGAN(myTrainOneStepCellForD, myTrainOneStepCellForG)
    dcgan.set_train()

    # checkpoint save
    # The checkpoint callback is driven by hand (begin/step_end below) rather
    # than through a Model.train call, so its parameters are filled in
    # manually. Only the generator is handed to the callback.
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch,
                                   keep_checkpoint_max=args.epoch_size)
    ckpt_cb = ModelCheckpoint(
        config=ckpt_config, directory=local_output_url, prefix='dcgan')

    cb_params = _InternalCallbackParam()
    cb_params.train_network = netG
    cb_params.batch_num = steps_per_epoch
    cb_params.epoch_num = args.epoch_size
    # For each epoch
    cb_params.cur_epoch_num = 0
    cb_params.cur_step_num = 0
    run_context = RunContext(cb_params)
    ckpt_cb.begin(run_context)

    # Fixed, seeded noise so the sample grids of different epochs are
    # generated from the same 16 latent vectors.
    np.random.seed(1)
    fixed_noise = Tensor(np.random.normal(
        size=(16, cfg.latent_size, 1, 1)).astype("float32"))

    data_loader = ds.create_dict_iterator(
        output_numpy=True, num_epochs=args.epoch_size)
    G_losses = []
    D_losses = []
    # Start Training Loop
    print("Starting Training Loop...")
    for epoch in range(args.epoch_size):
        # For each batch in the dataloader
        for i, data in enumerate(data_loader):
            real_data = Tensor(data['image'])
            latent_code = Tensor(data["latent_code"])
            netD_loss, netG_loss = dcgan(real_data, latent_code)
            if i % 50 == 0:
                print("Date time: ", datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), "\tepoch: ", epoch, "/",
                      args.epoch_size, "\tstep: ", i, "/", steps_per_epoch, "\tDloss: ", netD_loss, "\tGloss: ",
                      netG_loss)
            # The loss history keeps growing over the whole run; it is what
            # save_losses plots at the end of each epoch.
            D_losses.append(netD_loss.asnumpy())
            G_losses.append(netG_loss.asnumpy())
            cb_params.cur_step_num = cb_params.cur_step_num + 1
        cb_params.cur_epoch_num = cb_params.cur_epoch_num + 1
        print("================saving model===================")
        # Only device 0 writes checkpoints/plots and uploads them.
        if device_id == 0:
            ckpt_cb.step_end(run_context)
            fake = netG(fixed_noise)
            print("================saving images===================")
            save_imgs(fake.asnumpy(), epoch + 1)
            print("================saving losses===================")
            save_losses(G_losses, D_losses, epoch + 1)
            mox.file.copy_parallel(
                src_url=local_images_url, dst_url=args.images_url)
            mox.file.copy_parallel(
                src_url=local_losses_url, dst_url=args.losses_url)
            mox.file.copy_parallel(
                src_url=local_output_url, dst_url=args.train_url)
        print("================success================")

    # export checkpoint file into air, onnx, mindir models
    # The dummy input takes its latent dimension from the config, matching
    # fixed_noise above, instead of assuming a latent size of 100.
    # NOTE(review): export and upload run on every device, not only device 0
    # -- confirm this is intended for multi-device jobs.
    inputs = Tensor(np.random.rand(16, cfg.latent_size, 1, 1), mstype.float32)
    export(netG, inputs, file_name=args.file_name,
           file_format=args.file_format)
    file_name = args.file_name + "." + args.file_format.lower()
    mox.file.copy_parallel(
        src_url=file_name, dst_url=os.path.join(args.train_url, file_name))
